
[Pallas] Switch gather to jnp.take_along_axis (for JAX issue filing)#2061

Draft
AmesingFlank wants to merge 1 commit into AmesingFlank/stack/25 from AmesingFlank/stack/27


@AmesingFlank AmesingFlank commented Apr 20, 2026

Stacked PRs:


[Pallas] Switch gather to jnp.take_along_axis (for JAX issue filing)

This version uses jnp.take_along_axis, which is the natural JAX equivalent
of torch.gather. It works in interpret mode but fails on real TPU due to
a limitation in Mosaic's lax.gather lowering rule, which requires
indices.shape == input.shape + (1,).

Co-Authored-By: Claude Sonnet 4 <noreply@anthropic.com>
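
For readers less familiar with the two APIs, here is a minimal sketch of the equivalence the description relies on. It runs in plain JAX on the default backend and does not reproduce the TPU failure. The helper `gather_reference` and the sample arrays are hypothetical, written only for illustration; the shape in `required_index_shape` is taken from the description above, not verified against Mosaic.

```python
import jax.numpy as jnp


def gather_reference(x, dim, index):
    """Pure-Python reference for torch.gather semantics on 2D nested lists.

    Hypothetical helper for illustration only; not part of this PR.
    """
    rows, cols = len(index), len(index[0])
    out = [[0] * cols for _ in range(rows)]
    for i in range(rows):
        for j in range(cols):
            # torch.gather: out[i][j] = x[index[i][j]][j] when dim == 0,
            #               out[i][j] = x[i][index[i][j]] when dim == 1.
            out[i][j] = x[index[i][j]][j] if dim == 0 else x[i][index[i][j]]
    return out


x = jnp.array([[10, 11, 12], [20, 21, 22]])
index = jnp.array([[2, 0, 1], [1, 1, 0]])

# jnp.take_along_axis picks, for every position, the element of x that the
# index array names along the chosen axis.
gathered = jnp.take_along_axis(x, index, axis=1)
expected = gather_reference(x.tolist(), 1, index.tolist())

# Index shape that, per the description, Mosaic's lax.gather lowering rule
# requires: the input shape with one trailing unit dimension appended.
required_index_shape = x.shape + (1,)

print(gathered.tolist())     # [[12, 10, 11], [21, 21, 20]]
print(required_index_shape)  # (2, 3, 1)
```

The reference loop is only there to make the per-element rule explicit; in practice the single `jnp.take_along_axis` call is the piece that corresponds to `torch.gather`.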

@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/26 branch from 1a3b7f5 to 696b52e Compare April 20, 2026 21:39
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/27 branch from 3a1edd6 to 15cbc51 Compare April 20, 2026 21:39
@meta-cla meta-cla Bot added the CLA Signed label Apr 20, 2026
@AmesingFlank AmesingFlank marked this pull request as draft April 20, 2026 21:42
AmesingFlank added a commit that referenced this pull request Apr 20, 2026
stack-info: PR: #2061, branch: AmesingFlank/stack/27
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/26 to main April 20, 2026 22:00
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/27 branch from 15cbc51 to b1ca465 Compare April 20, 2026 22:00
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/26 April 20, 2026 22:00
AmesingFlank added a commit that referenced this pull request Jun 4, 2026
AmesingFlank added a commit that referenced this pull request Jun 4, 2026
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/26 to main June 4, 2026 01:36
AmesingFlank added a commit that referenced this pull request Jun 4, 2026
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/27 branch from b1ca465 to 9550c41 Compare June 4, 2026 01:36
AmesingFlank added a commit that referenced this pull request Jun 4, 2026
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/27 branch from 9550c41 to 3182883 Compare June 4, 2026 01:37
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/26 June 4, 2026 01:37
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/26 to main June 4, 2026 01:50
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/26 June 4, 2026 01:50
AmesingFlank added a commit that referenced this pull request Jun 4, 2026
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/26 to main June 4, 2026 01:52
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/27 branch from 3182883 to 921bcaf Compare June 4, 2026 01:52
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/26 June 4, 2026 01:52
AmesingFlank added a commit that referenced this pull request Jun 4, 2026
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/26 to main June 4, 2026 01:59
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/27 branch from 921bcaf to 33543af Compare June 4, 2026 01:59
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/26 June 4, 2026 01:59
AmesingFlank added a commit that referenced this pull request Jun 4, 2026
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/26 to main June 4, 2026 02:06
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/27 branch from 33543af to 4e3bc80 Compare June 4, 2026 02:06
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/26 June 4, 2026 02:07
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/26 to main June 4, 2026 02:12
AmesingFlank added a commit that referenced this pull request Jun 4, 2026
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/27 branch from 4e3bc80 to f32d40d Compare June 4, 2026 02:12
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/25 June 4, 2026 02:12
@AmesingFlank AmesingFlank changed the base branch from AmesingFlank/stack/25 to main June 4, 2026 05:11
@AmesingFlank AmesingFlank force-pushed the AmesingFlank/stack/27 branch from f32d40d to 9ef6bb5 Compare June 4, 2026 05:11
@AmesingFlank AmesingFlank changed the base branch from main to AmesingFlank/stack/25 June 4, 2026 05:11